#include <omp.h>
#include <math.h>
#include <iostream>
#include <fstream>
#include <sys/stat.h>
#include "include/Image.hpp"
#include "include/ImageWriter.hpp"
#include "include/skeleton.hpp"
#include "include/connected.hpp"
#include "afmm/include/safemalloc.hpp"
#include "include/messages.h"
#include "ImageStatistics.hpp"



/* Path handed to the ImageWriter in Image::computeSkeletons().
 * Spelled std::string explicitly because this definition precedes the
 * file's `using namespace std;` and must not rely on a header leaking one. */
std::string OUTPUT_FILE;

using namespace std;

/*************** CONSTRUCTORS ***************/
Image::Image(FIELD<float> *in, unsigned int islandThresh, float importanceThresh) {
    PRINT(MSG_NORMAL, "Creating Image Object...\n");
    /* One layer per grey level of an (assumed) 8-bit image. */
    numLayers       = 256;
    islandThreshold = islandThresh;
    layerThreshold  = importanceThresh;
    /* Allocated later by calculateImportance(). */
    importance      = NULL;
    /* The destructor deletes 'im', so this object takes ownership of 'in'. */
    im              = in;
    nPix            = im->dimX() * im->dimY();
    PRINT(MSG_NORMAL, "Done!\n");
}

/*
 * Convenience constructor: no island filtering and no layer filtering
 * (both thresholds are 0). Takes ownership of 'in' (deleted by ~Image()).
 *
 * The members are initialised explicitly here. Writing `Image(in, 0, 0);`
 * in the body does NOT delegate. It only constructs and immediately destroys
 * a temporary Image, whose destructor deletes 'in'. That would leave this
 * object uninitialised and holding a dangling pointer.
 */
Image::Image(FIELD<float> *in) {
    PRINT(MSG_NORMAL, "Creating Image Object...\n");
    this->numLayers = 256; /* Initially there are 256 layers (assuming 8 bit images) */
    this->layerThreshold = 0;
    this->islandThreshold = 0;
    this->importance = NULL;
    this->im = in;
    this->nPix = in->dimX() * in->dimY();
    PRINT(MSG_NORMAL, "Done!\n");
}

/* Releases the owned input field and the importance table. The table comes
 * from SAFE_CALLOC, so it is released with free(); free(NULL) is a no-op. */
Image::~Image() {
    delete this->im;
    free(this->importance);
}

/*************** FUNCTIONS **************/

/*
 * Calculate the normalised histogram of the image, used as the importance of each level:
 *   importance[v] = (#pixels with value v) / (#pixels in the most populated level).
 * The raw data pointer is walked instead of calling im->value(), which recomputes the
 * offset for every pixel; the visiting order is irrelevant for a histogram.
 *
 * Counting is done in integers, so levels with more than 2^24 pixels are not truncated
 * by float precision. Pixel values outside [0, 255] (or NaN) are clamped into the table
 * instead of indexing outside it. For an empty image all importances stay 0.
 */
void Image::calculateImportance() {
    PRINT(MSG_NORMAL, "Calculating the importance for each layer...\n");
    unsigned long counts[256] = {0};
    const float *c = im->data();
    const float *end = c + nPix;
    for (; c < end; ++c) {
        const float v = *c;
        unsigned int bin;
        if (!(v >= 0.0f)) bin = 0;          /* negative or NaN */
        else if (v >= 255.0f) bin = 255;    /* saturate above the 8-bit range */
        else bin = (unsigned int) v;
        counts[bin]++;
    }

    unsigned long normFac = 0;
    for (int i = 0; i < 256; ++i)
        if (counts[i] > normFac) normFac = counts[i];

    /* If importance was already calculated before, cleanup! */
    if (importance) free(importance);
    importance = (float *) SAFE_CALLOC(256 * sizeof (float));

    /* The table always has 256 entries, so normalise all of them. Skip the division when
     * no pixel was counted, which would otherwise produce 0/0 = NaN. */
    if (normFac > 0)
        for (int i = 0; i < 256; ++i)
            importance[i] = (float) ((double) counts[i] / (double) normFac);
    PRINT(MSG_NORMAL, "Done!\n");

    for (int i = 0; i < 256; i++)
        PRINT(MSG_VERBOSE, "Level %d: %6.5f\n", i, importance[i]);
    PRINT(MSG_VERBOSE, "\n");
}


/**
 * fullDilate and fullErode are placeholders. Although they do work, they
 * should be replaced by better erosion and dilation functions. They are used primarily
 * to test how much a closing (dilation followed by erosion) of the skeleton reduces the image size.
 */

/*
 * fullDilate -- Perform dilation with a S.E. of 3x3, filled with ones.
 * Every pixel within the 3x3 neighbourhood of a non-zero input pixel becomes 255;
 * all other pixels are 0. Only neighbours inside the image are written, so pixels
 * on the border never cause writes outside the field's buffer.
 * 'layer' is deleted; the caller owns the returned field.
 */
FIELD<float> * fullDilate(FIELD<float> *layer){
    const int w = layer->dimX();
    const int h = layer->dimY();
    FIELD<float> *ret = new FIELD<float>(w, h);
    memset(ret->data(), 0, sizeof(float) * w * h);
    for (int y = 0; y < h; ++y) {
        for (int x = 0; x < w; ++x) {
            if (!layer->value(x, y)) continue;
            for (int ny = y - 1; ny <= y + 1; ++ny) {
                if (ny < 0 || ny >= h) continue;
                for (int nx = x - 1; nx <= x + 1; ++nx) {
                    if (nx < 0 || nx >= w) continue;
                    ret->set(nx, ny, 255);
                }
            }
        }
    }
    delete layer;
    return ret;
}

/*
 * fullErode -- Perform erosion with a S.E. of 3x3, filled with ones.
 * A pixel becomes 255 when every in-image pixel of its 3x3 neighbourhood is non-zero,
 * and 0 otherwise. Neighbours outside the image are ignored, so the image border
 * itself does not erode the shape. This also means no pixel outside the field's
 * buffer is ever read. 'layer' is deleted; the caller owns the returned field.
 */
FIELD<float> * fullErode(FIELD<float> *layer){
    const int w = layer->dimX();
    const int h = layer->dimY();
    FIELD<float> *ret = new FIELD<float>(w, h);
    for (int y = 0; y < h; ++y) {
        for (int x = 0; x < w; ++x) {
            bool allSet = true;
            for (int ny = y - 1; allSet && ny <= y + 1; ++ny) {
                if (ny < 0 || ny >= h) continue;
                for (int nx = x - 1; nx <= x + 1; ++nx) {
                    if (nx < 0 || nx >= w) continue;
                    if (!layer->value(nx, ny)) { allSet = false; break; }
                }
            }
            ret->set(x, y, allSet ? 255 : 0);
        }
    }
    delete layer;
    return ret;
}

/*
 * rmObject - Clear (set to 0) every non-zero cell of the 3x3 kernel k that is
 * 8-connected to (x, y), including (x, y) itself. A call with (x, y) outside the
 * kernel, or on a zero cell, does nothing. Cell (x, y) is stored at k[y*3 + x].
 * Used for removeDoubleSkel.
 */
void rmObject(int *k, int x, int y){
    const bool inside = x >= 0 && x <= 2 && y >= 0 && y <= 2;
    if (!inside) return;
    int &cell = k[y * 3 + x];
    if (cell == 0) return;
    cell = 0;
    /* Visit all neighbours; revisiting (x, y) itself stops at once since it is now 0. */
    for (int dy = -1; dy <= 1; ++dy)
        for (int dx = -1; dx <= 1; ++dx)
            rmObject(k, x + dx, y + dy);
}
/*
 * numObjects - Return the number of 8-connected groups of non-zero cells in the
 * 3x3 kernel k. Destructive: every cell of k is 0 afterwards.
 * Used for removeDoubleSkel.
 */
int numObjects(int *k){
    int objects = 0;
    for (int idx = 0; idx < 9; ++idx) {
        if (k[idx] == 0) continue;
        rmObject(k, idx % 3, idx / 3);
        ++objects;
    }
    return objects;
}
/* End count code */

/**
 * removeDoubleSkel -- thin away redundant skeleton points (e.g. rows of width 2).
 *
 * Foreground pixels are visited in raster order. A pixel is cleared when:
 *  - the sum of its 8 neighbours (converted to int) exceeds 256, which for a 0/255
 *    layer means at least two set neighbours (assumption: skeleton values are 0/255); and
 *  - its non-zero neighbours form a single 8-connected group (numObjects(k) < 2).
 * Such a pixel is not needed to keep its neighbours connected.
 * The layer is modified in place, so later decisions see earlier removals.
 * Neighbours outside the image count as 0 and are never read.
 *
 * @param FIELD<float> * layer -- the layer of which the skeleton should be reduced (modified in place)
 * @return the same pointer 'layer'; no copy is made.
 */
FIELD<float> * removeDoubleSkel(FIELD<float> *layer){
    const int w = layer->dimX();
    const int h = layer->dimY();
    int k[9]; /* 3x3 neighbourhood, k[(dy+1)*3 + (dx+1)], centre always 0 */
    for (int y = 0; y < h; ++y) {
        for (int x = 0; x < w; ++x) {
            if (!layer->value(x, y)) continue;
            int sum = 0;
            for (int dy = -1; dy <= 1; ++dy) {
                for (int dx = -1; dx <= 1; ++dx) {
                    const int nx = x + dx;
                    const int ny = y + dy;
                    int v = 0;
                    if ((dx != 0 || dy != 0) && nx >= 0 && nx < w && ny >= 0 && ny < h)
                        v = (int) layer->value(nx, ny);
                    k[(dy + 1) * 3 + (dx + 1)] = v;
                    sum += v;
                }
            }
            /* numObjects() consumes k, which is rebuilt for every pixel anyway. */
            if (sum > 256 && numObjects(k) < 2)
                layer->set(x, y, 0);
        }
    }
    return layer;
}


/**
 * Given a binary layer, remove all islands (connected components) smaller than iThresh pixels.
 * A pixel of a too-small component is "removed" by replacing its value v with 255 - v.
 * This flips it to the other phase, assuming the layer only holds 0 and 255.
 * The layer is modified in place.
 * @param layer   binary layer to filter (modified in place)
 * @param iThresh minimum component size, in pixels, that is kept
 */
void Image::removeIslands(FIELD<float>*layer, unsigned int iThresh){
    const int               count   = layer->dimX() * layer->dimY();
    ConnectedComponents     *CC     = new ConnectedComponents(255);
    int                     *ccaOut = new int[count];
    float                   *fdata  = layer->data();

    /* CCA -- labels go to ccaOut, the highest label used is returned */
    const int highestLabel = CC->connected(fdata, ccaOut, layer->dimX(), layer->dimY(), std::equal_to<float>(), true);

    /* Component sizes. The trailing "()" value-initialises the array so every counter
     * starts at 0; a plain new[] would leave them indeterminate. */
    unsigned int *hist = new unsigned int[highestLabel + 1]();
    for (int j = 0; j < count; j++) hist[ccaOut[j]]++;

    /* Remove small islands */
    for (int j = 0; j < count; j++) {
        if (hist[ccaOut[j]] < iThresh) fdata[j] = 255 - fdata[j];
    }

    /* Cleanup */
    delete [] hist;
    delete [] ccaOut;
    delete CC;
}


/*
 * Remove small islands according the the islandThreshold variable. Notice that both "on" and "off"
 * regions will be filtered.
 */
void Image::removeIslands() {
    int i, j, k;                    /* Iterators */
    FIELD<float> *inDuplicate = 0;  /* Duplicate, because of inplace modifications */
    FIELD<float> *newImg = new FIELD<float>(im->dimX(), im->dimY());
    int highestLabel;               /* for the CCA */
    int *ccaOut;                    /* labeled output */
    ConnectedComponents *CC;        /* CCA-object */
    float *fdata;
    unsigned int *hist;

    PRINT(MSG_NORMAL, "Removing small islands...\n");
    /* Some basic initialization */
    numLayers = 256;
    memset(newImg->data(), 0, nPix * sizeof(float));
    
    /* Connected Component Analysis */
#pragma omp parallel for private(i,j,k,ccaOut,CC,fdata,highestLabel,hist, inDuplicate)
    for (i = 0; i < 256; i++) {
        PRINT(MSG_VERBOSE, "Layer: %d\n", i);
        CC = new ConnectedComponents(255);
        ccaOut = new int[nPix];

        inDuplicate = (*im).dupe();
        inDuplicate->threshold((float) i);
        
        fdata = inDuplicate->data();

        /* CCA -- store highest label in 'max' -- Calculate histogram */
        highestLabel = CC->connected(fdata, ccaOut, im->dimX(), im->dimY(), std::equal_to<float>(), true);
        hist = (unsigned int *) SAFE_CALLOC((highestLabel + 1) * sizeof (unsigned int));
        for (j = 0; j < nPix; j++) hist[ccaOut[j]]++;

        /* Remove small islands */      
        for (j = 0; j < nPix; j++) {
            fdata[j] = (hist[ccaOut[j]] >= islandThreshold) ? fdata[j] : 255 - fdata[j];
        }
        
#pragma omp critical
        {
            for (j = 0; j < im->dimY(); j++)
                for (k = 0; k < im->dimX(); k++)
                    if (0 == fdata[j * im->dimX() + k] && newImg->fvalue(k,j) < i) newImg->set(k, j, i);
        }
        
        /* Cleanup */
        free(hist);
        delete [] ccaOut;
        delete CC;
        delete inDuplicate;
        
    }
        delete im;
        im = newImg;

    PRINT(MSG_NORMAL, "Done!\n");
}

/**
 * Remove unimportant layers -- Filter all layers for which their importance is lower than layerThreshold
 */
void Image::removeLayers() {
    float val;

    PRINT(MSG_NORMAL, "Filtering image layers...\n");
    PRINT(MSG_VERBOSE, "The following grayscale intensities are removed:\n");
    if (MSG_LEVEL == MSG_VERBOSE)
        for (int i = 0; i < 256; i++)
            if (importance[i] < layerThreshold)
                PRINT(MSG_VERBOSE, "(%d, %6.5f)\n", i, importance[i]);



    for (int y = 0; y < im->dimY(); y++) {
        for (int x = 0; x < im->dimX(); x++) {
            val = im->value(x, y);
            while (val > 0 && importance[(int) val] < layerThreshold) {
                val--;
            }
            im->set(x, y, val);
        }
    }
}


/**
 * Calculate the skeleton for each layer.
 */
void Image::computeSkeletons() {
    int i, lastLayer=0;
    FIELD<float> *imDupeCurr = 0;
    FIELD<float> *imDupePrev = 0;
    FIELD<float> *skelCurr=0;
    FIELD<float> *skelPrev=0;

    PRINT(MSG_NORMAL, "Creating ImageWriter object...\n");
    ImageWriter iw(OUTPUT_FILE.c_str());
    iw.writeHeader(im->nx, im->ny);
    PRINT(MSG_NORMAL, "Computing the skeleton for all layers...\n");
    imDupePrev = im->dupe();
    imDupePrev->threshold(0);
    skelPrev = computeSkeleton(imDupePrev, 0);
    for (i=1; i<256; i++) {
    PRINT(MSG_NORMAL, "Layer: %3d\r", i);
        if (importance[i] >= layerThreshold) {
            imDupeCurr = im->dupe();
            imDupeCurr->threshold(i);
           
            skelCurr = computeSkeleton(imDupeCurr, i);
            //skelCurr = fullDilate(skelCurr);
            //skelCurr = fullErode(skelCurr);
            //IS_analyseLayer("closing", skelCurr, i);
            skelCurr = removeDoubleSkel(skelCurr);
            iw.writeLayer(skelPrev, imDupePrev,(unsigned char) lastLayer);

            delete skelPrev;
            delete imDupePrev;

            skelPrev = skelCurr;
            imDupePrev = imDupeCurr;
            lastLayer = i;
       }
    }
    iw.writeLayer(skelPrev, imDupePrev,lastLayer);
    delete skelCurr;
    delete imDupeCurr;
    PRINT(MSG_NORMAL, "\n");
    PRINT(MSG_NORMAL, "Done!\n");
}
